'''
Author: Simon
Date: 2022-03-10 15:41:02
LastEditTime: 2022-03-10 16:20:24
LastEditors: Please set LastEditors
Description: Get the path of every image in the Dataset and the Dataset length
FilePath: \pytorch-tutorial\LoadingData_DataVisualization\read_data.py
'''
from torch.utils.data import Dataset
from PIL import Image
import os


class MyData(Dataset):
    """Dataset of images stored in a single label sub-directory.

    Every entry of ``<root_dir>/<label_dir>`` is treated as one sample, and the
    label returned for each sample is the ``label_dir`` string itself.
    """

    def __init__(self, root_dir, label_dir) -> None:
        """Index the entries under ``root_dir/label_dir``.

        Args:
            root_dir: Directory containing one sub-directory per label.
            label_dir: Name of the sub-directory to load; also used as the
                label returned for every sample.

        Raises:
            FileNotFoundError: If ``root_dir/label_dir`` does not exist
                (raised by ``os.listdir``).
        """
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir returns entries in arbitrary, platform-dependent order;
        # sort so that a given index always maps to the same file everywhere.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at position ``index``.

        The image is opened with ``PIL.Image.open``; the label is the
        ``label_dir`` string passed to the constructor.
        """
        img_name = self.img_path[index]
        img_item_path = os.path.join(self.path, img_name)
        # NOTE: Image.open is lazy -- pixel data is read only when the image is
        # loaded/converted, and the file handle stays open until then.
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of entries found in the label directory."""
        return len(self.img_path)
    

# Training split of the hymenoptera (ants vs. bees) data, relative to the
# current working directory. Built with os.path.join so the path is valid on
# every OS (a backslash-separated literal only works on Windows).
root_dir = os.path.join(
    "dataset", "hymenoptera_data", "hymenoptera_data", "train"
)

ants_label_dir = "ants"
bees_label_dir = "bees"

ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

# Dataset.__add__ yields a torch ConcatDataset: all ants samples, then bees.
train_dataset = ants_dataset + bees_dataset
